{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import math\n",
    "import torch\n",
    "from torch.distributions import Normal, MultivariateNormal\n",
    "from torch import matmul"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Univariate Normal Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributions import Normal\n",
    "\n",
    "# Set the parameters of the distribution\n",
    "mu = torch.tensor([0.0], dtype=torch.float)\n",
    "sigma = torch.tensor([5.0], dtype=torch.float)\n",
    "\n",
    "# Instantiate the univariate normal distribution\n",
    "uvn_dist = Normal(mu, sigma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log Prob: -2.528\n",
      "Raw eval Log Prob: -2.528\n"
     ]
    }
   ],
   "source": [
    "# Instantiate single point test dataset\n",
    "X = torch.tensor([0.0], dtype=torch.float)\n",
    "\n",
    "# Function to evaluate log prob using math formula\n",
    "def raw_eval(X, mu, sigma):\n",
    "    K = 1 / (math.sqrt(2 * math.pi) * sigma)\n",
    "    E = math.exp( -1 * (X - mu) ** 2 * (1 / (2 * sigma ** 2)))\n",
    "    return torch.log(K * E)\n",
    "\n",
    "# Evaluate log-prob using PyTorch distributions function call\n",
    "log_prob = uvn_dist.log_prob(X)\n",
    "print(\"Log Prob: {:.3f}\".format(log_prob[0]))\n",
    "\n",
    "# Evaluate log-prob using formula\n",
    "raw_eval_log_prob = raw_eval(X, mu, sigma)\n",
    "print(\"Raw eval Log Prob: {:.3f}\".format(raw_eval_log_prob[0]))\n",
    "\n",
    "assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples to draw\n",
    "num_samples = 100000\n",
    "\n",
    "# Draw samples\n",
    "samples = uvn_dist.sample([num_samples])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample Mean: 0.022\n",
      "Dist Mean: 0.000\n",
      "Sample Variance: 24.858\n",
      "Dist Variance: 25.000\n"
     ]
    }
   ],
   "source": [
    "# The mean obtained from the samples\n",
    "sample_mean = samples.mean()\n",
    "print(\"Sample Mean: {:.3f}\".format(sample_mean))\n",
    "\n",
    "# The mean of the distribution from Pytorch\n",
    "dist_mean = uvn_dist.mean\n",
    "print(\"Dist Mean: {:.3f}\".format(dist_mean[0]))\n",
    "\n",
    "# As expected, the two means approximately match\n",
    "assert torch.isclose(sample_mean, dist_mean, atol=0.3)\n",
    "\n",
    "# The variance obtained from the samples\n",
    "sample_var = uvn_dist.sample([num_samples]).var()\n",
    "print(\"Sample Variance: {:.3f}\".format(sample_var))\n",
    "\n",
    "# The variance of the distribution from Pytorch\n",
    "dist_var = uvn_dist.variance\n",
    "print(\"Dist Variance: {:.3f}\".format(dist_var[0]))\n",
    "\n",
    "# As expected, the two variances approximately match\n",
    "assert torch.isclose(sample_var, dist_var, atol=0.3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Visualization\n",
    "\n",
    "Here we allow the user to set different values for the mean and variance of a univariate normal distribution and visualise the resulting distribution. \n",
    "Specifically, notice that changing the mean does not change the shape of the distribution. It just varies where the distribution peaks. Changing the variance causes the distribution to either become more diffuse / peaked.\n",
    "\n",
    "Note: In order to run this section, please download the notebook. Interactive snippets do not work online. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de3cb8d299a64806afdd547197d566e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02e4f2d4e5594c4eb397c858d426f7ae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.0, description='mu', max=40.0, min=-40.0, step=0.5), FloatSlider(val…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "import ipywidgets as widgets\n",
    "from ipywidgets import interact, interact_manual\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "ax.set_title(\"Univariate Normal Distribution\")\n",
    "ax.set_ylabel(\"P(X)\")\n",
    "ax.set_xlabel(\"X\")\n",
    "\n",
    "\n",
    "@interact\n",
    "def plot_univariate_normal(mu=(-40, 40, 0.5), sigma=(4, 30, 0.5)):\n",
    "    x = np.linspace(mu - 3*sigma, mu + 3*sigma, 1000)\n",
    "    [l.remove() for l in ax.lines]\n",
    "    uvn_dist = Normal(mu, sigma)\n",
    "    pdf = uvn_dist.log_prob(torch.from_numpy(x)).exp()\n",
    "    ax.set_xlim(-75, 75)\n",
    "    ax.set_ylim(0, 0.1)\n",
    "    ax.plot(x, pdf, 'green')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multivariate Normal Distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.distributions import MultivariateNormal\n",
    "\n",
    "# Set the parameters of the distribution\n",
    "mu = torch.tensor([0.0, 0.0], dtype=torch.float)\n",
    "C = torch.tensor([[5.0, 0.0], [0.0, 5.0]], dtype=torch.float)\n",
    "\n",
    "# Instantiate the multivariate normal distribution\n",
    "mvn_dist = MultivariateNormal(mu, C)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Log Prob: -3.447\n",
      "Raw eval Log Prob: -3.447\n"
     ]
    }
   ],
   "source": [
    "# Instantiate single point test dataset\n",
    "X = torch.tensor([0.0, 0.0], dtype=torch.float)\n",
    "\n",
    "# Function to evaluate log prob using math formula\n",
    "def raw_eval(X, mu, C):\n",
    "    K = (1 / (2 * math.pi * torch.sqrt(C.det())))\n",
    "    X_minus_mu = (X - mu).reshape(-1, 1)\n",
    "    E1 = torch.matmul(X_minus_mu.T, C.inverse())\n",
    "    E = torch.exp(-1 / 2. * torch.matmul(E1, X_minus_mu))\n",
    "    return torch.log(K * E)\n",
    "\n",
    "# Evaluate log-prob using PyTorch distributions function call\n",
    "log_prob = mvn_dist.log_prob(X)\n",
    "print(\"Log Prob: {:.3f}\".format(log_prob))\n",
    "\n",
    "# Evaluate log-prob using formula\n",
    "raw_eval_log_prob = raw_eval(X, mu, C)\n",
    "print(\"Raw eval Log Prob: {:.3f}\".format(raw_eval_log_prob[0][0]))\n",
    "\n",
    "assert torch.isclose(log_prob, raw_eval_log_prob, atol=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples to draw\n",
    "num_samples = 100000\n",
    "\n",
    "# Draw samples\n",
    "samples = mvn_dist.sample([num_samples])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample Mean: tensor([-0.0001, -0.0004])\n",
      "Dist Mean: tensor([0., 0.])\n",
      "Sample Variance: tensor([4.9463, 5.0443])\n",
      "Dist Variance: tensor([5., 5.])\n"
     ]
    }
   ],
   "source": [
    "# The mean obtained from the samples\n",
    "sample_mean = samples.mean(axis=0)\n",
    "print(\"Sample Mean: {}\".format(sample_mean))\n",
    "\n",
    "# The mean of the distribution from Pytorch\n",
    "dist_mean = mvn_dist.mean\n",
    "print(\"Dist Mean: {}\".format(dist_mean))\n",
    "\n",
    "# As expected, the two means approximately match\n",
    "assert torch.allclose(sample_mean, dist_mean, atol=1e-1)\n",
    "\n",
    "# The variance obtained from the samples\n",
    "sample_var = mvn_dist.sample([num_samples]).var(axis=0)\n",
    "print(\"Sample Variance: {}\".format(sample_var))\n",
    "\n",
    "# The variance of the distribution from Pytorch\n",
    "dist_var = mvn_dist.variance\n",
    "print(\"Dist Variance: {}\".format(dist_var))\n",
    "\n",
    "# As expected, the two variances approximately match\n",
    "assert torch.allclose(sample_var, dist_var, atol=1e-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Visualization\n",
    "\n",
    "Here we allow the user to set different values for the means and covariance matrix of a 2D Normal distribution and visualise the resulting distribution. \n",
    "\n",
    "Specifically, notice that changing the mean does not change the shape of the distribution. It just varies where the distribution peaks. Changing $\\mu_{0}$ shifts the center along the X axis. Similarly changing $\\mu_{1}$ shifts the center along the Y-axis\n",
    "\n",
    "While providing values for the covariance matrix, we should ensure that the matrix is not singular.\n",
    "\n",
    "Note: In order to run this section, please download the notebook. Interactive snippets do not work online. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f03709d6cf434ae5907c4f7f77596750",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/ananya.h.a/.virtualenvs/dl_book/lib/python3.7/site-packages/ipykernel_launcher.py:6: UserWarning: Requested projection is different from current axis projection, creating new axis with requested projection.\n",
      "  \n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d29d7d454de140fe8965766c369b28c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.0, description='mu_0', max=5.0, min=-5.0, step=0.25), FloatSlider(va…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from mpl_toolkits.mplot3d import Axes3D # <--- This is important for 3d plotting \n",
    "from matplotlib import cm\n",
    "\n",
    "fig_1, ax_1 = plt.subplots(nrows=1, ncols=1)\n",
    "ax_1.set_title(\"Bivariate Normal Distribution\")\n",
    "ax_1 = fig_1.gca(projection='3d')\n",
    "\n",
    "\n",
    "@interact\n",
    "def plot_2d_normal(\n",
    "    mu_0=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),\n",
    "    mu_1=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),\n",
    "    sigma_00=widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0),\n",
    "    sigma_01=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),\n",
    "    sigma_11 =widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0)):\n",
    "    \n",
    "    def _reset_plot(ax):\n",
    "        ax.clear()\n",
    "        ax.set_ylabel(\"Y\")\n",
    "        ax.set_xlabel(\"X\")\n",
    "        ax.set_zlabel(\"P(X,Y)\")\n",
    "        \n",
    "    X = np.linspace(-10, 10, 1000)\n",
    "    Y = np.linspace(-10, 10, 1000)\n",
    "    X, Y = np.meshgrid(X, Y)\n",
    "    XY = np.stack((X, Y), axis=2)\n",
    "    mu = np.array([mu_0, mu_1])\n",
    "    sigma_10 = sigma_01 # Covariance matrix is symmetric\n",
    "\n",
    "    C = np.array([[sigma_00, sigma_01], [sigma_10, sigma_11]])\n",
    "    try:\n",
    "        mvn_dist = MultivariateNormal(torch.from_numpy(mu), torch.from_numpy(C))\n",
    "        Z = mvn_dist.log_prob(torch.from_numpy(XY)).exp().numpy()\n",
    "        # Plot the surface.\n",
    "        _reset_plot(ax_1)\n",
    "        ax_1.plot_surface(X, Y, Z, cmap=cm.coolwarm,\n",
    "                       linewidth=0, antialiased=False)\n",
    "    except RuntimeError:\n",
    "        print(\"Error!: Covariance matrix cannot be singular!\")\n",
    "        ax_1.clear()\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Contour plots obtained from Bivariate Normal Distributions\n",
    "\n",
    "For any Bivariate normal distribution, the plot of the surface $p (x, y)$ against $(x, y)$ looks like a bell in 3D space. The shape of the bell’s base, on the $(x, y)$ plane, is governed by the 2x2 matrix $\\sum$\n",
    "\n",
    "If $\\sum$ is a diagonal matrix with equal diagonal elements, the bell is symmetric in all directions,\n",
    "its base is circular\n",
    "\n",
    "If $\\sum$ is a diagonal matrix with unequal diagonal elements, the base of the bell is elliptical.\n",
    "The axes of the ellipse are aligned with coordinate axes.\n",
    "\n",
    "For general $\\sum$ matrix the base of the bell is elliptical. The axes of the ellipse are not necessarily\n",
    "aligned with coordinate axes.\n",
    "\n",
    "Observe the following as you interact with the visualization\n",
    "- When $\\mu_{0}$ increases, the base of the bell shifts along the X-axis.\n",
    "- When $\\mu_{1}$ increases, the base of the bell shifts along the Y-axis.\n",
    "- When $\\sigma_{00}$ increases, the spread along the X-axis increases.\n",
    "- When $\\sigma_{11}$ increases, the spread along the Y-axis increases.\n",
    "\n",
    "Note: In order to run this section, please download the notebook. Interactive snippets do not work online. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39426f6c8217481aa7300151692cfa52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e670dde66afa49d3a33a1a92d57a80eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "interactive(children=(FloatSlider(value=0.0, description='mu_0', max=5.0, min=-5.0, step=0.25), FloatSlider(va…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "fig_2, ax_2 = plt.subplots(nrows=1, ncols=1)\n",
    "\n",
    "\n",
    "@interact\n",
    "def plot_2d_normal_contour(\n",
    "    mu_0=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),\n",
    "    mu_1=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),\n",
    "    sigma_00=widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0),\n",
    "    sigma_01=widgets.FloatSlider(min=-5, max=5, step=0.25, value=0.0),\n",
    "    sigma_11 =widgets.FloatSlider(min=0, max=5, step=0.25, value=1.0)):\n",
    "\n",
    "    def _reset_plot(ax):\n",
    "        ax.clear()\n",
    "        ax_2.set_title(\"Base of the Bivariate Normal Distribution\")\n",
    "        \n",
    "    X = np.linspace(-10, 10, 1000)\n",
    "    Y = np.linspace(-10, 10, 1000)\n",
    "    X, Y = np.meshgrid(X, Y)\n",
    "    XY = np.stack((X, Y), axis=2)\n",
    "    mu = np.array([mu_0, mu_1])\n",
    "    sigma_10 = sigma_01 # Covariance matrix is symmetrical\n",
    "\n",
    "    C = np.array([[sigma_00, sigma_01], [sigma_10, sigma_11]])\n",
    "    try:\n",
    "        mvn_dist = MultivariateNormal(torch.from_numpy(mu), torch.from_numpy(C))\n",
    "        Z = mvn_dist.log_prob(torch.from_numpy(XY)).exp()\n",
    "        _reset_plot(ax_2)\n",
    "        ax_2.contour(Z)\n",
    "    except RuntimeError:\n",
    "        print(\"Error!: The covariance matrix must not be singular\")\n",
    "        ax_2.clear()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
